# This code is part of Qiskit.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""The Active-Space Reduction interface."""

from typing import List, Optional, Tuple, Union
import copy
import logging
import numpy as np

from qiskit_nature import QiskitNatureError
from qiskit_nature.deprecation import DeprecatedType, warn_deprecated_same_type_name
from qiskit_nature.drivers import QMolecule

from .base_transformer import BaseTransformer

logger = logging.getLogger(__name__)


class ActiveSpaceTransformer(BaseTransformer):
    r"""**DEPRECATED!** The Active-Space reduction.

    The reduction is done by computing the inactive Fock operator which is defined as
    :math:`F^I_{pq} = h_{pq} + \sum_i 2 g_{iipq} - g_{iqpi}` and the inactive energy which is
    given by :math:`E^I = \sum_j h_{jj} + F ^I_{jj}`, where :math:`i` and :math:`j` iterate over
    the inactive orbitals.
    By using the inactive Fock operator in place of the one-electron integrals, `h1`, the
    description of the active space contains an effective potential generated by the inactive
    electrons. Therefore, this method permits the exclusion of non-core electrons while
    retaining a high-quality description of the system.

    For more details on the computation of the inactive Fock operator refer to
    https://arxiv.org/abs/2009.01872.

    The active space can be configured in one of the following ways through the initializer:
        - when only `num_electrons` and `num_molecular_orbitals` are specified, these integers
          indicate the number of active electrons and orbitals, respectively. The active space will
          then be chosen around the Fermi level resulting in a unique choice for any pair of
          numbers.  Nonetheless, the following criteria must be met:

            #. the remaining number of inactive electrons must be a positive, even number

            #. the number of active orbitals must not exceed the total number of orbitals minus the
               number of orbitals occupied by the inactive electrons

        - when `num_electrons` is given as a tuple `(num_alpha, num_beta)` instead of a single
          integer, the numbers of active alpha and beta electrons are specified separately. This
          can be used to disambiguate the active space in systems with non-zero spin. The same
          requirements as listed in the previous case must be met.
        - finally, it is possible to select a custom set of active orbitals via their indices using
          `active_orbitals`. This allows selecting an active space which is not placed around the
          Fermi level as described in the first case, above. When using this keyword argument, the
          following criteria must be met *in addition* to the ones listed above:

            #. the length of `active_orbitals` must be equal to `num_molecular_orbitals`.

            #. the sum of electrons present in `active_orbitals` must be equal to `num_electrons`.

    References:
        - *M. Rossmannek, P. Barkoutsos, P. Ollitrault, and I. Tavernelli, arXiv:2009.01872
          (2020).*
    """

    def __init__(
        self,
        num_electrons: Optional[Union[int, Tuple[int, int]]] = None,
        num_molecular_orbitals: Optional[int] = None,
        active_orbitals: Optional[List[int]] = None,
    ):
        """Stores the active-space configuration used later by :meth:`transform`.

        The transformation itself reads the AO-basis matrices `hcore` and `eri` as well as the
        basis-transformation matrix `mo_coeff` from the `QMolecule`. A `QMolecule` produced by
        Qiskit's drivers generally provides these, unless it was read from an FCIDump file. In that
        case the integrals are likely already reduced by the code which produced the file.
        Alternatively, the MO-basis integrals can be copied into the AO-basis containers and
        `mo_coeff` can be initialized with an identity matrix of appropriate size before applying
        this transformer.

        No validation happens here; the configuration is only checked when `transform` is called.

        Args:
            num_electrons: The number of active electrons. An integer is interpreted as the total
                           number of active electrons; it should be even and is split evenly
                           between alpha and beta spin. A tuple gives the number of alpha and beta
                           electrons explicitly.
            num_molecular_orbitals: The number of active orbitals.
            active_orbitals: A list of indices specifying the molecular orbitals of the active
                             space. It must be consistent with the other two arguments and is only
                             needed to enforce an active space that is not chosen purely around
                             the Fermi level.
        """
        # Notify users about the deprecation before anything else happens.
        warn_deprecated_same_type_name(
            "0.2.0",
            DeprecatedType.CLASS,
            "ActiveSpaceTransformer",
            "from qiskit_nature.transformers.second_quantization.electronic as a direct replacement",
        )

        # user-provided configuration
        self._active_orbitals = active_orbitals
        self._num_molecular_orbitals = num_molecular_orbitals
        self._num_electrons = num_electrons

        # intermediate results; all of these are populated while `transform` runs
        self._num_particles: Optional[Tuple[int, int]] = None
        self._beta: Optional[bool] = None
        self._mo_occ_total: Optional[np.ndarray] = None
        self._mo_occ_inactive: Optional[Tuple[np.ndarray, np.ndarray]] = None
        self._mo_coeff_inactive: Optional[Tuple[np.ndarray, np.ndarray]] = None
        self._mo_coeff_active: Optional[Tuple[np.ndarray, np.ndarray]] = None
        self._density_inactive: Optional[Tuple[np.ndarray, np.ndarray]] = None

    def transform(self, molecule_data: QMolecule) -> QMolecule:
        """Reduces the given `QMolecule` to a given active space.

        Args:
            molecule_data: the `QMolecule` to be transformed.

        Returns:
            A new `QMolecule` instance.

        Raises:
            QiskitNatureError: If more electrons or orbitals are requested than are available, if an
                               uneven number of inactive electrons remains, or if the number of
                               selected active orbital indices does not match
                               `num_molecular_orbitals`.
        """
        try:
            self._check_configuration()
        except QiskitNatureError as exc:
            raise QiskitNatureError("Incorrect Active-Space configuration.") from exc

        # get molecular orbital coefficients
        mo_coeff_full = (molecule_data.mo_coeff, molecule_data.mo_coeff_b)
        self._beta = mo_coeff_full[1] is not None
        # get molecular orbital occupation numbers
        mo_occ_full = self._extract_mo_occupation_vector(molecule_data)
        self._mo_occ_total = mo_occ_full[0] + mo_occ_full[1] if self._beta else mo_occ_full[0]

        active_orbs_idxs, inactive_orbs_idxs = self._determine_active_space(molecule_data)

        # split molecular orbitals coefficients into active and inactive parts
        self._mo_coeff_inactive = (
            mo_coeff_full[0][:, inactive_orbs_idxs],
            mo_coeff_full[1][:, inactive_orbs_idxs] if self._beta else None,
        )
        self._mo_coeff_active = (
            mo_coeff_full[0][:, active_orbs_idxs],
            mo_coeff_full[1][:, active_orbs_idxs] if self._beta else None,
        )
        self._mo_occ_inactive = (
            mo_occ_full[0][inactive_orbs_idxs],
            mo_occ_full[1][inactive_orbs_idxs] if self._beta else None,
        )

        self._compute_inactive_density_matrix()

        # construct new QMolecule
        molecule_data_reduced = copy.deepcopy(molecule_data)
        # Energies and orbitals
        molecule_data_reduced.num_molecular_orbitals = self._num_molecular_orbitals
        molecule_data_reduced.num_alpha = self._num_particles[0]
        molecule_data_reduced.num_beta = self._num_particles[1]
        molecule_data_reduced.mo_coeff = self._mo_coeff_active[0]
        molecule_data_reduced.mo_coeff_b = self._mo_coeff_active[1]
        molecule_data_reduced.orbital_energies = molecule_data.orbital_energies[active_orbs_idxs]
        if self._beta:
            molecule_data_reduced.orbital_energies_b = molecule_data.orbital_energies_b[
                active_orbs_idxs
            ]
        molecule_data_reduced.kinetic = None
        molecule_data_reduced.overlap = None

        # reduce electronic energy integrals
        self._reduce_to_active_space(
            molecule_data,
            molecule_data_reduced,
            "energy_shift",
            ("hcore", "hcore_b"),
            ("mo_onee_ints", "mo_onee_ints_b"),
            "eri",
            ("mo_eri_ints", "mo_eri_ints_ba", "mo_eri_ints_bb"),
        )

        # reduce dipole moment integrals
        if molecule_data.has_dipole_integrals():
            self._reduce_to_active_space(
                molecule_data,
                molecule_data_reduced,
                "x_dip_energy_shift",
                ("x_dip_ints", None),
                ("x_dip_mo_ints", "x_dip_mo_ints_b"),
            )
            self._reduce_to_active_space(
                molecule_data,
                molecule_data_reduced,
                "y_dip_energy_shift",
                ("y_dip_ints", None),
                ("y_dip_mo_ints", "y_dip_mo_ints_b"),
            )
            self._reduce_to_active_space(
                molecule_data,
                molecule_data_reduced,
                "z_dip_energy_shift",
                ("z_dip_ints", None),
                ("z_dip_mo_ints", "z_dip_mo_ints_b"),
            )

        return molecule_data_reduced

    def _check_configuration(self):
        if isinstance(self._num_electrons, int):
            if self._num_electrons % 2 != 0:
                raise QiskitNatureError(
                    "The number of active electrons must be even! Otherwise you must specify them "
                    "as a tuple, not as:",
                    str(self._num_electrons),
                )
            if self._num_electrons < 0:
                raise QiskitNatureError(
                    "The number of active electrons cannot be negative, not:",
                    str(self._num_electrons),
                )
        elif isinstance(self._num_electrons, tuple):
            if not all(isinstance(n_elec, int) and n_elec >= 0 for n_elec in self._num_electrons):
                raise QiskitNatureError(
                    "Neither the number of alpha, nor the number of beta electrons can be "
                    "negative, not:",
                    str(self._num_electrons),
                )
        else:
            raise QiskitNatureError(
                "The number of active electrons must be an int, or a tuple thereof, not:",
                str(self._num_electrons),
            )

        if isinstance(self._num_molecular_orbitals, int):
            if self._num_molecular_orbitals < 0:
                raise QiskitNatureError(
                    "The number of active orbitals cannot be negative, not:",
                    str(self._num_molecular_orbitals),
                )
        else:
            raise QiskitNatureError(
                "The number of active orbitals must be an int, not:",
                str(self._num_electrons),
            )

    def _extract_mo_occupation_vector(self, molecule_data: QMolecule):
        mo_occ_full = (molecule_data.mo_occ, molecule_data.mo_occ_b)
        if mo_occ_full[0] is None:
            # QMolecule provided by driver without `mo_occ` information available. Constructing
            # occupation numbers based on ground state HF case.
            occ_alpha = [1.0] * molecule_data.num_alpha + [0.0] * (
                molecule_data.num_molecular_orbitals - molecule_data.num_alpha
            )
            if self._beta:
                occ_beta = [1.0] * molecule_data.num_beta + [0.0] * (
                    molecule_data.num_molecular_orbitals - molecule_data.num_beta
                )
            else:
                occ_alpha[: molecule_data.num_beta] = [
                    o + 1 for o in occ_alpha[: molecule_data.num_beta]
                ]
                occ_beta = None
            mo_occ_full = (np.asarray(occ_alpha), np.asarray(occ_beta))
        return mo_occ_full

    def _determine_active_space(self, molecule_data: QMolecule):
        if isinstance(self._num_electrons, tuple):
            num_alpha, num_beta = self._num_electrons
        elif isinstance(self._num_electrons, int):
            num_alpha = num_beta = self._num_electrons // 2

        # compute number of inactive electrons
        nelec_total = molecule_data.num_alpha + molecule_data.num_beta
        nelec_inactive = nelec_total - num_alpha - num_beta

        self._num_particles = (num_alpha, num_beta)

        self._validate_num_electrons(nelec_inactive)
        self._validate_num_orbitals(nelec_inactive, molecule_data)

        # determine active and inactive orbital indices
        if self._active_orbitals is None:
            norbs_inactive = nelec_inactive // 2
            inactive_orbs_idxs = list(range(norbs_inactive))
            active_orbs_idxs = list(
                range(norbs_inactive, norbs_inactive + self._num_molecular_orbitals)
            )
        else:
            active_orbs_idxs = self._active_orbitals
            inactive_orbs_idxs = [
                o
                for o in range(nelec_total // 2)
                if o not in self._active_orbitals and self._mo_occ_total[o] > 0
            ]

        return (active_orbs_idxs, inactive_orbs_idxs)

    def _validate_num_electrons(self, nelec_inactive: int):
        """Validates the number of electrons.

        Args:
            nelec_inactive: the computed number of inactive electrons.

        Raises:
            QiskitNatureError: if the number of inactive electrons is either negative or odd.
        """
        if nelec_inactive < 0:
            raise QiskitNatureError("More electrons requested than available.")
        if nelec_inactive % 2 != 0:
            raise QiskitNatureError("The number of inactive electrons must be even.")

    def _validate_num_orbitals(self, nelec_inactive: int, molecule_data: QMolecule):
        """Validates the number of orbitals.

        Args:
            nelec_inactive: the computed number of inactive electrons.
            molecule_data: the `QMolecule` to be transformed.

        Raises:
            QiskitNatureError: if more orbitals were requested than are available in total or if the
                               number of selected orbitals mismatches the specified number of active
                               orbitals.
        """
        if self._active_orbitals is None:
            norbs_inactive = nelec_inactive // 2
            if norbs_inactive + self._num_molecular_orbitals > molecule_data.num_molecular_orbitals:
                raise QiskitNatureError("More orbitals requested than available.")
        else:
            if self._num_molecular_orbitals != len(self._active_orbitals):
                raise QiskitNatureError(
                    "The number of selected active orbital indices does not "
                    "match the specified number of active orbitals."
                )
            if max(self._active_orbitals) >= molecule_data.num_molecular_orbitals:
                raise QiskitNatureError("More orbitals requested than available.")
            if sum(self._mo_occ_total[self._active_orbitals]) != self._num_electrons:
                raise QiskitNatureError(
                    "The number of electrons in the selected active orbitals "
                    "does not match the specified number of active electrons."
                )

    def _compute_inactive_density_matrix(self):
        """Computes the inactive density matrix."""
        density_inactive_a = np.dot(
            self._mo_coeff_inactive[0] * self._mo_occ_inactive[0],
            np.transpose(self._mo_coeff_inactive[0]),
        )
        density_inactive_b = None
        if self._beta:
            density_inactive_b = np.dot(
                self._mo_coeff_inactive[1] * self._mo_occ_inactive[1],
                np.transpose(self._mo_coeff_inactive[1]),
            )
        self._density_inactive = (density_inactive_a, density_inactive_b)

    def _reduce_to_active_space(
        self,
        molecule_data: QMolecule,
        molecule_data_reduced: QMolecule,
        energy_shift_attribute: str,
        ao_1e_attribute: Tuple[str, Optional[str]],
        mo_1e_attribute: Tuple[str, str],
        ao_2e_attribute: Optional[str] = None,
        mo_2e_attribute: Optional[Tuple[str, str, str]] = None,
    ) -> None:
        """A utility method which performs the actual orbital reduction computation.

        Args:
            molecule_data: the original `QMolecule` object.
            molecule_data_reduced: the reduced `QMolecule` object.
            energy_shift_attribute: the name of the attribute which stores the energy shift.
            ao_1e_attribute: the names of the AO-basis 1-electron matrices.
            mo_1e_attribute: the names of the MO-basis 1-electron matrices.
            ao_2e_attribute: the name of the AO-basis 2-electron matrix.
            mo_2e_attribute: the names of the MO-basis 2-electron matrices.
        """
        ao_1e_matrix_a = getattr(molecule_data, ao_1e_attribute[0])

        if self._beta:
            ao_1e_matrix_b = None
            if ao_1e_attribute[1] is not None:
                # It is possible that no beta-spin one-electron AO matrices are available (e.g. for
                # dipole integrals). Then, this attribute name will be None.
                ao_1e_matrix_b = getattr(molecule_data, ao_1e_attribute[1])
            if ao_1e_matrix_b is None:
                # Furthermore, even if the attribute name was present, the object itself may be
                # None. In that case we will simply use the alpha-spin pendant.
                ao_1e_matrix_b = ao_1e_matrix_a
        else:
            ao_1e_matrix_b = None

        ao_1e_matrix = (ao_1e_matrix_a, ao_1e_matrix_b)

        if ao_2e_attribute:
            ao_2e_matrix = getattr(molecule_data, ao_2e_attribute)
        else:
            ao_2e_matrix = None

        if ao_2e_matrix is None:
            # no 2-electron AO matrix is given
            inactive_op = copy.deepcopy(ao_1e_matrix)
        else:
            inactive_op = self._compute_inactive_fock_op(ao_1e_matrix, ao_2e_matrix)

        if self._beta and inactive_op[1] is None:
            # To a similar reasoning as for ao_1e_matrix_b, it is possible that this object is None.
            # If that is the case, we fallback to using the alpha-spin pendant.
            inactive_op = (inactive_op[0], inactive_op[0])

        energy_shift = self._compute_inactive_energy(ao_1e_matrix, inactive_op)

        mo_1e_matrix, mo_2e_matrix = self._compute_active_integrals(inactive_op, ao_2e_matrix)

        getattr(molecule_data_reduced, energy_shift_attribute)[
            "ActiveSpaceTransformer"
        ] = energy_shift
        setattr(molecule_data_reduced, ao_1e_attribute[0], inactive_op[0])
        setattr(molecule_data_reduced, mo_1e_attribute[0], mo_1e_matrix[0])
        if self._beta:
            if ao_1e_attribute[1] is not None:
                setattr(molecule_data_reduced, ao_1e_attribute[1], inactive_op[1])
            setattr(molecule_data_reduced, mo_1e_attribute[1], mo_1e_matrix[1])
        if mo_2e_matrix is not None:
            setattr(molecule_data_reduced, mo_2e_attribute[0], mo_2e_matrix[0])
            if self._beta:
                setattr(molecule_data_reduced, mo_2e_attribute[1], mo_2e_matrix[1])
                setattr(molecule_data_reduced, mo_2e_attribute[2], mo_2e_matrix[2])

    def _compute_inactive_fock_op(
        self,
        hcore: Tuple[np.ndarray, Optional[np.ndarray]],
        eri: np.ndarray,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Computes the inactive Fock operator.

        Args:
            hcore: the alpha- and beta-spin core Hamiltonian pair.
            eri: the electron-repulsion-integrals in MO format.

        Returns:
            The pair of alpha- and beta-spin inactive Fock operators.
        """
        # compute inactive Fock matrix
        coulomb_inactive = np.einsum("ijkl,ji->kl", eri, self._density_inactive[0])
        exchange_inactive = np.einsum("ijkl,jk->il", eri, self._density_inactive[0])
        fock_inactive = hcore[0] + coulomb_inactive - 0.5 * exchange_inactive
        fock_inactive_b = coulomb_inactive_b = exchange_inactive_b = None

        if self._beta:
            coulomb_inactive_b = np.einsum("ijkl,ji->kl", eri, self._density_inactive[1])
            exchange_inactive_b = np.einsum("ijkl,jk->il", eri, self._density_inactive[1])
            fock_inactive = hcore[0] + coulomb_inactive + coulomb_inactive_b - exchange_inactive
            fock_inactive_b = hcore[1] + coulomb_inactive + coulomb_inactive_b - exchange_inactive_b

        return (fock_inactive, fock_inactive_b)

    def _compute_inactive_energy(
        self,
        hcore: Tuple[np.ndarray, Optional[np.ndarray]],
        fock_inactive: Tuple[np.ndarray, Optional[np.ndarray]],
    ) -> float:
        """Computes the inactive energy.

        Args:
            hcore: the alpha- and beta-spin core Hamiltonian pair.
            fock_inactive: the alpha- and beta-spin inactive fock operator pair.

        Returns:
            The inactive energy.
        """
        # compute inactive energy
        e_inactive = 0.0
        if not self._beta and self._mo_coeff_inactive[0].size > 0:
            e_inactive += 0.5 * np.einsum(
                "ij,ji", self._density_inactive[0], hcore[0] + fock_inactive[0]
            )
        elif self._beta and self._mo_coeff_inactive[1].size > 0:
            e_inactive += 0.5 * np.einsum(
                "ij,ji", self._density_inactive[0], hcore[0] + fock_inactive[0]
            )
            e_inactive += 0.5 * np.einsum(
                "ij,ji", self._density_inactive[1], hcore[1] + fock_inactive[1]
            )

        return e_inactive

    def _compute_active_integrals(
        self,
        fock_inactive: Tuple[np.ndarray, Optional[np.ndarray]],
        eri: Optional[np.ndarray] = None,
    ) -> Tuple[
        Tuple[np.ndarray, Optional[np.ndarray]],
        Optional[Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]],
    ]:
        """Computes the h1 and h2 integrals for the active space.

        Args:
            fock_inactive: the alpha- and beta-spin inactive fock operator pair.
            eri: the electron-repulsion-integrals in MO format.

        Returns:
            The h1 and h2 integrals for the active space. The storage format is the following:
                ((alpha-spin h1, beta-spin h1),
                 (alpha-alpha-spin h2, beta-alpha-spin h2, beta-beta-spin h2))
        """
        # compute new 1- and 2-electron integrals
        hij = np.dot(
            np.dot(np.transpose(self._mo_coeff_active[0]), fock_inactive[0]),
            self._mo_coeff_active[0],
        )
        hij_b = None
        if self._beta:
            hij_b = np.dot(
                np.dot(np.transpose(self._mo_coeff_active[1]), fock_inactive[1]),
                self._mo_coeff_active[1],
            )

        if eri is None:
            return ((hij, hij_b), None)

        hijkl = np.einsum(
            "pqrs,pi,qj,rk,sl->ijkl",
            eri,
            self._mo_coeff_active[0],
            self._mo_coeff_active[0],
            self._mo_coeff_active[0],
            self._mo_coeff_active[0],
            optimize=True,
        )

        hijkl_bb = hijkl_ba = None

        if self._beta:
            hijkl_bb = np.einsum(
                "pqrs,pi,qj,rk,sl->ijkl",
                eri,
                self._mo_coeff_active[1],
                self._mo_coeff_active[1],
                self._mo_coeff_active[1],
                self._mo_coeff_active[1],
                optimize=True,
            )
            hijkl_ba = np.einsum(
                "pqrs,pi,qj,rk,sl->ijkl",
                eri,
                self._mo_coeff_active[1],
                self._mo_coeff_active[1],
                self._mo_coeff_active[0],
                self._mo_coeff_active[0],
                optimize=True,
            )

        return (hij, hij_b), (hijkl, hijkl_ba, hijkl_bb)
